from utils import *
from model import *
from numpy import linalg as LA
import math
import pdb
import time

def trans_curve_Heter(d, epsilon, epoch, n, p, mus, qs):
    """Run the transition-curve sweep using the ``*_heter`` aggregators.

    For every ``mu`` in ``mus`` and every ``q`` in ``qs`` the simulation is
    repeated ``epoch`` times.  Each repetition draws ``CSBM(n, epsilon, p, q)``
    and ``high_dim_Gaussian`` features, then scores three predictions against
    the node labels: ``np.sign`` of the raw features, the labels returned by
    ``message_passing_heter`` and the labels returned by
    ``linear_convolution_heter``.  Per-cell mean/std of each accuracy and the
    elapsed wall-clock time are printed and appended to a table that is
    written to a time-stamped ``.npy`` file once the sweep has finished.

    Args:
        d: Feature dimension; each ``mu`` is expanded to ``mu * np.ones(d)``.
        epsilon: Forwarded unchanged to ``CSBM``.
        epoch: Number of repetitions per (mu, q) cell.
        n: First argument of ``CSBM``; also the denominator of each accuracy.
        p: Fixed value passed to ``CSBM`` and ``message_passing_heter``.
        mus: Iterable of feature-strength values to sweep.
        qs: Iterable of ``q`` values to sweep.

    Returns:
        None.  Results are printed and saved with ``np.save``.
    """
    star_rule = '*****************************************************'
    cell_rule = '--------------------------------------------'
    trial_rule = "------------------------------------------------"

    print(star_rule)
    print('Experiments for Transition Curve')
    print(star_rule)

    def simulate(q_val, mu_vec):
        # Call order is kept exactly: graph, features, linear, nonlinear.
        adjacency, labels = CSBM(n, epsilon, p, q_val)
        feats, feats_tilde = high_dim_Gaussian(labels, d, mu_vec)
        # linear convolution
        g_lin, y_lin = linear_convolution_heter(feats, adjacency, d)
        # nonlinear convolution
        g_mp, y_mp = message_passing_heter(feats, adjacency, p, q_val, d)
        return y_lin, y_mp, feats, labels

    def accuracy(pred, labels):
        return sum(pred == labels) / n

    table = []

    for mu in mus:
        mu_vec = mu * np.ones(d)
        for q in qs:
            tic = time.time()
            raw_acc = np.zeros(epoch)
            mp_acc = np.zeros(epoch)
            lin_acc = np.zeros(epoch)

            print(cell_rule)
            print(cell_rule)
            print('n = ', n)
            print('p = ', p, ', q = ', q)
            print('mu_norm = ', LA.norm(mu_vec))

            for trial in range(epoch):
                print('Epoch: ' + str(trial))
                y_lin, y_mp, feats, labels = simulate(q, mu_vec)

                # Score raw features, nonlinear aggregation, linear aggregation.
                raw_score = accuracy(np.sign(feats), labels)
                mp_score = accuracy(y_mp, labels)
                lin_score = accuracy(y_lin, labels)

                raw_acc[trial] = raw_score
                mp_acc[trial] = mp_score
                lin_acc[trial] = lin_score
                print(trial_rule)

            raw_mean, raw_std = sum(raw_acc) / epoch, np.std(raw_acc)
            mp_mean, mp_std = sum(mp_acc) / epoch, np.std(mp_acc)
            lin_mean, lin_std = sum(lin_acc) / epoch, np.std(lin_acc)

            print("Average acc of original:", raw_mean)
            print("Average acc of mp:", mp_mean)
            print("Average acc of linear:", lin_mean)
            print("STD of original:", raw_std)
            print("STD of mp:", mp_std)
            print("STD of linear:", lin_std)

            elapsed = time.time() - tic
            print("Dt:", elapsed)

            row = [np.sqrt(d) * LA.norm(mu_vec), np.log(p / q),
                   raw_mean, raw_std, mp_mean, mp_std, lin_mean, lin_std,
                   elapsed]
            table.append(row)

    np.save(f'trans_cruve_heter_{time.strftime("%m%d%y_%H%M%S")}.npy', table)

def train_laplacian(d, epsilon, epoch, n_asymptotic, bs):
    """Laplacian experiment, "psi only" versus "phi only" comparison variant.

    For every ``b`` in ``bs`` and every ``n`` in ``n_asymptotic`` the
    simulation is repeated ``epoch`` times using ``forward_compare``.  In this
    variant the "linear" record holds ``linear_convolution`` applied to the
    ``psi_Laplace`` output, and the "mp" record holds
    ``optimal_nonlinear_propagation_general`` applied to the same output.

    NOTE: a second ``def train_laplacian`` appears later in this file and
    rebinds the name, so this version is replaced once the module has been
    fully executed.  The results table built here is not saved (the
    ``np.save`` call at the end is commented out) and nothing is returned.

    Args:
        d: Feature dimension forwarded to the feature/aggregation helpers.
        epsilon: Forwarded unchanged to ``CSBM``.
        epoch: Number of repetitions per (b, n) cell.
        n_asymptotic: Iterable of ``n`` values to sweep.
        bs: Iterable of ``b`` values forwarded to ``psi_Laplace``.

    Returns:
        None.
    """
    star_rule = '*****************************************************'
    cell_rule = '--------------------------------------------'
    trial_rule = "------------------------------------------------"

    print(star_rule)
    print('Laplacian Experiment')
    print(star_rule)

    def parameter_setup(n):
        # Alternative setting kept for reference:
        #   p = math.pow(math.log(n), 3) / n
        #   q = 0.8 * math.pow(math.log(n), 3) / n
        root_n = math.pow(n, 0.5)
        return 0.05, 10 * root_n / n, 8 * root_n / n

    def forward(n, epsilon, p, q, mu, b):
        # Baseline pipeline; not called below, kept as the alternative to
        # forward_compare.
        adjacency, labels = CSBM(n, epsilon, p, q)
        feats, feats_tilde = high_dim_Laplace(labels, d, mu)
        # linear convolution
        g_lin, y_lin = linear_convolution(feats, adjacency, d)
        # nonlinear convolution
        g_psi = psi_Laplace(feats_tilde, mu, b, d)
        g_mp, y_mp = optimal_nonlinear_propagation_general(g_psi, adjacency, p, q)
        return g_lin, y_lin, g_mp, y_mp, g_psi, feats, feats_tilde, labels, adjacency

    def forward_compare(n, epsilon, p, q, mu, b):
        adjacency, labels = CSBM(n, epsilon, p, q)
        feats, feats_tilde = high_dim_Laplace(labels, d, mu)
        # Only psi
        g_psi = psi_Laplace(feats_tilde, mu, b, d)
        g_only_psi, y_only_psi = linear_convolution(g_psi, adjacency, d)
        # Only phi
        g_only_phi, y_only_phi = optimal_nonlinear_propagation_general(g_psi, adjacency, p, q)
        return (g_only_psi, y_only_psi, g_only_phi, y_only_phi,
                g_psi, feats, feats_tilde, labels, adjacency)

    table = []
    for b in bs:
        for n in n_asymptotic:
            tic = time.time()
            mu, p, q = parameter_setup(n)
            raw_acc = np.zeros(epoch)
            mp_acc = np.zeros(epoch)
            lin_acc = np.zeros(epoch)

            print(cell_rule)
            print(cell_rule)
            print('n = ', n)
            print('p = ', p, ', q = ', q)
            print(f'b:{b}')
            print('mu_norm = ', LA.norm(mu))
            print('mean/variance = ', math.sqrt(n * (p + q)) / (n * (p - q)))

            for trial in range(epoch):
                print('Epoch: ' + str(trial))
                # To run the baseline pipeline instead, call forward(...) here.
                outputs = forward_compare(n, epsilon, p, q, mu, b)
                y_lin, y_mp = outputs[1], outputs[3]
                feats, labels = outputs[5], outputs[7]

                # Raw features, "phi only" output, "psi only" output.
                raw_score = sum(np.sign(feats) == labels) / n
                mp_score = sum(y_mp == labels) / n
                lin_score = sum(y_lin == labels) / n

                raw_acc[trial] = raw_score
                mp_acc[trial] = mp_score
                lin_acc[trial] = lin_score
                print(trial_rule)

            raw_mean, raw_std = sum(raw_acc) / epoch, np.std(raw_acc)
            mp_mean, mp_std = sum(mp_acc) / epoch, np.std(mp_acc)
            lin_mean, lin_std = sum(lin_acc) / epoch, np.std(lin_acc)
            elapsed = time.time() - tic

            print("Average acc of original:", raw_mean)
            print("Average acc of mp:", mp_mean)
            print("Average acc of linear:", lin_mean)
            print("STD of original:", raw_std)
            print("STD of mp:", mp_std)
            print("STD of linear:", lin_std)
            print("Dt:", elapsed)

            table.append([LA.norm(mu), n * (p - q), b,
                          raw_mean, raw_std, mp_mean, mp_std, lin_mean, lin_std,
                          elapsed])
    # Saving is currently disabled:
    # np.save('Synthetic_Laplacian_compare.npy', results)

def trans_curve_Laplacian(d, epsilon, epoch, n, q, mus, ps, b=1):
    """Run the transition-curve sweep with Laplace-distributed features.

    For every ``mu`` in ``mus`` and every ``p`` in ``ps`` the simulation is
    repeated ``epoch`` times.  Each repetition draws ``CSBM(n, epsilon, p, q)``
    and ``high_dim_Laplace`` features, then scores three predictions against
    the node labels: ``np.sign`` of the raw features, the labels from
    ``optimal_nonlinear_propagation_general`` applied to the ``psi_Laplace``
    output, and the labels from ``linear_convolution``.  Per-cell statistics
    are printed and the whole table is written to a time-stamped ``.npy``
    file after the sweep.

    Args:
        d: Feature dimension forwarded to the feature/aggregation helpers.
        epsilon: Forwarded unchanged to ``CSBM``.
        epoch: Number of repetitions per (mu, p) cell.
        n: First argument of ``CSBM``; also the denominator of each accuracy.
        q: Fixed value passed to ``CSBM`` and the nonlinear propagation.
        mus: Iterable of ``mu`` values; each is passed on as-is (it is not
            expanded to a length-``d`` vector here).
        ps: Iterable of ``p`` values to sweep.
        b: Forwarded to ``psi_Laplace``.  Defaults to 1.

    Returns:
        None.  Results are printed and saved with ``np.save``.
    """
    star_rule = '*****************************************************'
    cell_rule = '--------------------------------------------'
    trial_rule = "------------------------------------------------"

    print(star_rule)
    print('Experiments for Transition Curve')
    print(star_rule)

    def simulate(p_val, mu_val):
        adjacency, labels = CSBM(n, epsilon, p_val, q)
        feats, feats_tilde = high_dim_Laplace(labels, d, mu_val)
        # linear convolution
        g_lin, y_lin = linear_convolution(feats, adjacency, d)
        # nonlinear convolution
        g_psi = psi_Laplace(feats_tilde, mu_val, b, d)
        g_mp, y_mp = optimal_nonlinear_propagation_general(g_psi, adjacency, p_val, q)
        return y_lin, y_mp, feats, labels

    def accuracy(pred, labels):
        return sum(pred == labels) / n

    table = []

    for mu in mus:
        for p in ps:
            tic = time.time()
            raw_acc = np.zeros(epoch)
            mp_acc = np.zeros(epoch)
            lin_acc = np.zeros(epoch)

            print(cell_rule)
            print(cell_rule)
            print('n = ', n)
            print('p = ', p, ', q = ', q)
            print('mu_norm = ', LA.norm(mu))

            for trial in range(epoch):
                print('Epoch: ' + str(trial))
                y_lin, y_mp, feats, labels = simulate(p, mu)

                # Score raw features, nonlinear aggregation, linear aggregation.
                raw_score = accuracy(np.sign(feats), labels)
                mp_score = accuracy(y_mp, labels)
                lin_score = accuracy(y_lin, labels)

                raw_acc[trial] = raw_score
                mp_acc[trial] = mp_score
                lin_acc[trial] = lin_score
                print(trial_rule)

            raw_mean, raw_std = sum(raw_acc) / epoch, np.std(raw_acc)
            mp_mean, mp_std = sum(mp_acc) / epoch, np.std(mp_acc)
            lin_mean, lin_std = sum(lin_acc) / epoch, np.std(lin_acc)

            print("Average acc of original:", raw_mean)
            print("Average acc of mp:", mp_mean)
            print("Average acc of linear:", lin_mean)
            print("STD of original:", raw_std)
            print("STD of mp:", mp_std)
            print("STD of linear:", lin_std)

            elapsed = time.time() - tic
            print("Dt:", elapsed)

            row = [np.sqrt(d) * LA.norm(mu), np.log(p / q),
                   raw_mean, raw_std, mp_mean, mp_std, lin_mean, lin_std,
                   elapsed]
            table.append(row)

    np.save(f'trans_cruve_Laplacian_{time.strftime("%m%d%y_%H%M%S")}.npy', table)

def train_laplacian(d, epsilon, epoch, n_asymptotic, bs):
    """Laplacian experiment: raw features vs. linear vs. nonlinear aggregation.

    For every ``b`` in ``bs`` and every ``n`` in ``n_asymptotic`` the
    simulation is repeated ``epoch`` times.  ``mu``, ``p`` and ``q`` are
    derived from ``n`` by ``parameter_setup``.  Each repetition draws a
    ``CSBM`` graph and ``high_dim_Laplace`` features and scores ``np.sign`` of
    the raw features, the ``linear_convolution`` labels and the
    ``optimal_nonlinear_propagation_general`` labels against the node labels.
    The table of per-cell statistics is saved to ``Synthetic_Laplacian.npy``.

    Args:
        d: Feature dimension forwarded to the feature/aggregation helpers.
        epsilon: Forwarded unchanged to ``CSBM``.
        epoch: Number of repetitions per (b, n) cell.
        n_asymptotic: Iterable of ``n`` values to sweep.
        bs: Iterable of ``b`` values forwarded to ``psi_Laplace``.

    Returns:
        None.  Results are printed and saved with ``np.save``.
    """
    star_rule = '*****************************************************'
    cell_rule = '--------------------------------------------'
    trial_rule = "------------------------------------------------"

    print(star_rule)
    print('Laplacian Experiment')
    print(star_rule)

    def parameter_setup(n):
        log_cubed = math.pow(math.log(n), 3)
        return 0.02, log_cubed / n, 0.8 * log_cubed / n

    def simulate(n, p, q, mu, b):
        adjacency, labels = CSBM(n, epsilon, p, q)
        feats, feats_tilde = high_dim_Laplace(labels, d, mu)
        # linear convolution
        g_lin, y_lin = linear_convolution(feats, adjacency, d)
        # nonlinear convolution
        g_psi = psi_Laplace(feats_tilde, mu, b, d)
        g_mp, y_mp = optimal_nonlinear_propagation_general(g_psi, adjacency, p, q)
        return y_lin, y_mp, feats, labels

    table = []
    for b in bs:
        for n in n_asymptotic:
            print(f'b:{b}')
            tic = time.time()
            mu, p, q = parameter_setup(n)
            raw_acc = np.zeros(epoch)
            mp_acc = np.zeros(epoch)
            lin_acc = np.zeros(epoch)

            print(cell_rule)
            print(cell_rule)
            print('n = ', n)
            print('p = ', p, ', q = ', q)
            print('mu_norm = ', LA.norm(mu))
            print('mean/variance = ', math.sqrt(n * (p + q)) / (n * (p - q)))

            for trial in range(epoch):
                print('Epoch: ' + str(trial))
                y_lin, y_mp, feats, labels = simulate(n, p, q, mu, b)

                # Score raw features, nonlinear aggregation, linear aggregation.
                raw_score = sum(np.sign(feats) == labels) / n
                mp_score = sum(y_mp == labels) / n
                lin_score = sum(y_lin == labels) / n

                raw_acc[trial] = raw_score
                mp_acc[trial] = mp_score
                lin_acc[trial] = lin_score
                print(trial_rule)

            raw_mean, raw_std = sum(raw_acc) / epoch, np.std(raw_acc)
            mp_mean, mp_std = sum(mp_acc) / epoch, np.std(mp_acc)
            lin_mean, lin_std = sum(lin_acc) / epoch, np.std(lin_acc)
            elapsed = time.time() - tic

            print("Average acc of original:", raw_mean)
            print("Average acc of mp:", mp_mean)
            print("Average acc of linear:", lin_mean)
            print("STD of original:", raw_std)
            print("STD of mp:", mp_std)
            print("STD of linear:", lin_std)
            print("Dt:", elapsed)

            table.append([LA.norm(mu), n * (p - q),
                          raw_mean, raw_std, mp_mean, mp_std, lin_mean, lin_std,
                          elapsed])
    np.save('Synthetic_Laplacian.npy', table)

def train_exp4(d, epsilon, epoch, n_asymptotic):
    """Run "Experiment 4": fixed ``p``/``q`` with ``mu`` depending on ``n``.

    For every ``n`` in ``n_asymptotic`` the simulation is repeated ``epoch``
    times with ``p = 0.033``, ``q = 0.03`` and
    ``mu = 0.02 * log(n)**(1/3) * np.ones(d)``.  Each repetition draws a
    ``CSBM`` graph and ``high_dim_Gaussian`` features and scores ``np.sign``
    of the raw features, the ``linear_convolution`` labels and the
    ``message_passing`` labels against the node labels.  The table of per-``n``
    statistics is saved to a time-stamped ``.npy`` file.

    Args:
        d: Feature dimension.
        epsilon: Forwarded unchanged to ``CSBM``.
        epoch: Number of repetitions per ``n``.
        n_asymptotic: Iterable of ``n`` values to sweep.

    Returns:
        None.  Results are printed and saved with ``np.save``.
    """
    star_rule = '*****************************************************'
    cell_rule = '--------------------------------------------'
    trial_rule = "------------------------------------------------"

    print(star_rule)
    print('Experiment 4')
    print(star_rule)

    def parameter_setup(n):
        mu_vec = 0.02 * math.pow(math.log(n), 1/3) * np.ones(d)
        return mu_vec, 0.033, 0.03

    def simulate(n, p, q, mu_vec):
        adjacency, labels = CSBM(n, epsilon, p, q)
        feats, feats_tilde = high_dim_Gaussian(labels, d, mu_vec)
        # linear convolution
        g_lin, y_lin = linear_convolution(feats, adjacency, d)
        # nonlinear convolution
        g_mp, y_mp = message_passing(feats, adjacency, p, q, d)
        # These combined features are still evaluated, as before, although
        # the caller below does not use them.
        skip = 2 * d * feats
        x_mp = skip + adjacency * g_mp
        x_lin = skip + adjacency * g_lin
        return x_mp, x_lin, y_lin, y_mp, feats, labels

    table = []

    for n in n_asymptotic:
        tic = time.time()
        mu, p, q = parameter_setup(n)
        raw_acc = np.zeros(epoch)
        mp_acc = np.zeros(epoch)
        lin_acc = np.zeros(epoch)

        print(cell_rule)
        print(cell_rule)
        print('n = ', n)
        print('p = ', p, ', q = ', q)
        print('mu_norm = ', LA.norm(mu))
        print('mean/variance = ', math.sqrt(n * (p + q)) / (n * (p - q)))

        for trial in range(epoch):
            print('Epoch: ' + str(trial))
            _, _, y_lin, y_mp, feats, labels = simulate(n, p, q, mu)

            # Score raw features, nonlinear aggregation, linear aggregation.
            raw_score = sum(np.sign(feats) == labels) / n
            mp_score = sum(y_mp == labels) / n
            lin_score = sum(y_lin == labels) / n

            raw_acc[trial] = raw_score
            mp_acc[trial] = mp_score
            lin_acc[trial] = lin_score
            print(trial_rule)

        raw_mean, raw_std = sum(raw_acc) / epoch, np.std(raw_acc)
        mp_mean, mp_std = sum(mp_acc) / epoch, np.std(mp_acc)
        lin_mean, lin_std = sum(lin_acc) / epoch, np.std(lin_acc)
        elapsed = time.time() - tic

        print("Average acc of original:", raw_mean)
        print("Average acc of mp:", mp_mean)
        print("Average acc of linear:", lin_mean)
        print("STD of original:", raw_std)
        print("STD of mp:", mp_std)
        print("STD of linear:", lin_std)
        print("Dt:", elapsed)

        table.append([LA.norm(mu), n * (p - q),
                      raw_mean, raw_std, mp_mean, mp_std, lin_mean, lin_std,
                      elapsed])
    np.save(f'Synthetic_ReLUbetter_{time.strftime("%m%d%y_%H%M%S")}.npy', table)


def train_robustness(d, epsilon, epoch, n, q, mu, delta_mus, ps, gamma):
    """Run the robustness sweep over feature perturbations and ``p``.

    For every ``delta_mu`` in ``delta_mus`` and every ``p`` in ``ps`` the
    simulation is repeated ``epoch`` times.  Each repetition draws
    ``CSBM(n, epsilon, p, q)`` and features from
    ``high_dim_Gaussian_robustness`` and scores ``np.sign`` of the raw
    features, the ``linear_convolution`` labels and the ``message_passing``
    labels against the node labels.  The accumulated table is re-saved to
    ``robustness3.npy`` after each ``delta_mu`` has been processed.

    Args:
        d: Feature dimension; ``mu`` and ``delta_mu`` are multiplied by
            ``np.ones(d)`` before use.
        epsilon: Forwarded unchanged to ``CSBM``.
        epoch: Number of repetitions per (delta_mu, p) cell.
        n: First argument of ``CSBM``; also the denominator of each accuracy.
        q: Fixed value passed to ``CSBM`` and ``message_passing``.
        mu: Base feature strength.
        delta_mus: Iterable of perturbation strengths to sweep.
        ps: Iterable of ``p`` values to sweep.
        gamma: Forwarded unchanged to ``high_dim_Gaussian_robustness``.

    Returns:
        None.  Results are printed and saved with ``np.save``.
    """
    star_rule = '*****************************************************'
    cell_rule = '--------------------------------------------'
    trial_rule = "------------------------------------------------"

    print(star_rule)
    print('Experiments for Robustness')
    print(star_rule)

    def simulate(p_val, mu_vec, delta_vec):
        adjacency, labels = CSBM(n, epsilon, p_val, q)
        feats, feats_tilde = high_dim_Gaussian_robustness(labels, d, mu_vec, delta_vec, gamma)
        # linear convolution
        g_lin, y_lin = linear_convolution(feats, adjacency, d)
        # nonlinear convolution
        g_mp, y_mp = message_passing(feats, adjacency, p_val, q, d)
        # These combined features are still evaluated, as before, although
        # the caller below does not use them.
        skip = 2 * d * feats
        x_mp = skip + adjacency * g_mp
        x_lin = skip + adjacency * g_lin
        return x_mp, x_lin, y_lin, y_mp, feats, labels

    def accuracy(pred, labels):
        return sum(pred == labels) / n

    table = []

    for delta_mu in delta_mus:
        delta_scalar = delta_mu  # the value as supplied, used for display
        for p in ps:
            tic = time.time()
            # Multiplying by ones(d) broadcasts a scalar to length d and
            # leaves a length-d vector unchanged.
            mu_vec = mu * np.ones(d)
            delta_vec = delta_mu * np.ones(d)
            raw_acc = np.zeros(epoch)
            mp_acc = np.zeros(epoch)
            lin_acc = np.zeros(epoch)

            print(cell_rule)
            print(cell_rule)
            print('n = ', n)
            print('p = ', p, ', q = ', q)
            print('delta_mu = ', delta_scalar)

            for trial in range(epoch):
                print('Epoch: ' + str(trial))
                _, _, y_lin, y_mp, feats, labels = simulate(p, mu_vec, delta_vec)

                # Score raw features, nonlinear aggregation, linear aggregation.
                raw_score = accuracy(np.sign(feats), labels)
                mp_score = accuracy(y_mp, labels)
                lin_score = accuracy(y_lin, labels)

                raw_acc[trial] = raw_score
                mp_acc[trial] = mp_score
                lin_acc[trial] = lin_score
                print(trial_rule)

            raw_mean, raw_std = sum(raw_acc) / epoch, np.std(raw_acc)
            mp_mean, mp_std = sum(mp_acc) / epoch, np.std(mp_acc)
            lin_mean, lin_std = sum(lin_acc) / epoch, np.std(lin_acc)

            print("Average acc of original:", raw_mean)
            print("Average acc of mp:", mp_mean)
            print("Average acc of linear:", lin_mean)
            print("STD of original:", raw_std)
            print("STD of mp:", mp_std)
            print("STD of linear:", lin_std)

            elapsed = time.time() - tic
            print("Dt:", elapsed)

            row = [np.sqrt(d) * LA.norm(delta_vec) / LA.norm(mu_vec), np.log(p / q),
                   raw_mean, raw_std, mp_mean, mp_std, lin_mean, lin_std,
                   elapsed]
            table.append(row)
        # Re-saved after every delta_mu; the file is overwritten each time.
        np.save('robustness3.npy', table)







def trans_curve(d, epsilon, epoch, n, q, mus, ps):
    """Run the transition-curve sweep with Gaussian features.

    For every ``mu`` in ``mus`` and every ``p`` in ``ps`` the simulation is
    repeated ``epoch`` times.  Each repetition draws ``CSBM(n, epsilon, p, q)``
    and ``high_dim_Gaussian`` features and scores ``np.sign`` of the raw
    features, the ``linear_convolution`` labels and the ``message_passing``
    labels against the node labels.  The accumulated table is re-saved to
    ``trans_cruve.npy`` after each ``mu`` has been processed.

    Args:
        d: Feature dimension; each ``mu`` is expanded to ``mu * np.ones(d)``.
        epsilon: Forwarded unchanged to ``CSBM``.
        epoch: Number of repetitions per (mu, p) cell.
        n: First argument of ``CSBM``; also the denominator of each accuracy.
        q: Fixed value passed to ``CSBM`` and ``message_passing``.
        mus: Iterable of feature-strength values to sweep.
        ps: Iterable of ``p`` values to sweep.

    Returns:
        None.  Results are printed and saved with ``np.save``.
    """
    star_rule = '*****************************************************'
    cell_rule = '--------------------------------------------'
    trial_rule = "------------------------------------------------"

    print(star_rule)
    print('Experiments for Transition Curve')
    print(star_rule)

    def simulate(p_val, mu_vec):
        adjacency, labels = CSBM(n, epsilon, p_val, q)
        feats, feats_tilde = high_dim_Gaussian(labels, d, mu_vec)
        # linear convolution
        g_lin, y_lin = linear_convolution(feats, adjacency, d)
        # nonlinear convolution
        g_mp, y_mp = message_passing(feats, adjacency, p_val, q, d)
        # These combined features are still evaluated, as before, although
        # the caller below does not use them.
        skip = 2 * d * feats
        x_mp = skip + adjacency * g_mp
        x_lin = skip + adjacency * g_lin
        return x_mp, x_lin, y_lin, y_mp, feats, labels

    def accuracy(pred, labels):
        return sum(pred == labels) / n

    table = []

    for mu in mus:
        for p in ps:
            tic = time.time()
            # Multiplying by ones(d) broadcasts a scalar to length d and
            # leaves a length-d vector unchanged.
            mu_vec = mu * np.ones(d)
            raw_acc = np.zeros(epoch)
            mp_acc = np.zeros(epoch)
            lin_acc = np.zeros(epoch)

            print(cell_rule)
            print(cell_rule)
            print('n = ', n)
            print('p = ', p, ', q = ', q)
            print('mu_norm = ', LA.norm(mu_vec))

            for trial in range(epoch):
                print('Epoch: ' + str(trial))
                _, _, y_lin, y_mp, feats, labels = simulate(p, mu_vec)

                # Score raw features, nonlinear aggregation, linear aggregation.
                raw_score = accuracy(np.sign(feats), labels)
                mp_score = accuracy(y_mp, labels)
                lin_score = accuracy(y_lin, labels)

                raw_acc[trial] = raw_score
                mp_acc[trial] = mp_score
                lin_acc[trial] = lin_score
                print(trial_rule)

            raw_mean, raw_std = sum(raw_acc) / epoch, np.std(raw_acc)
            mp_mean, mp_std = sum(mp_acc) / epoch, np.std(mp_acc)
            lin_mean, lin_std = sum(lin_acc) / epoch, np.std(lin_acc)

            print("Average acc of original:", raw_mean)
            print("Average acc of mp:", mp_mean)
            print("Average acc of linear:", lin_mean)
            print("STD of original:", raw_std)
            print("STD of mp:", mp_std)
            print("STD of linear:", lin_std)

            elapsed = time.time() - tic
            print("Dt:", elapsed)

            row = [np.sqrt(d)*LA.norm(mu_vec), np.log(p/q),
                   raw_mean, raw_std, mp_mean, mp_std, lin_mean, lin_std,
                   elapsed]
            table.append(row)
        # Re-saved after every mu; the file is overwritten each time.
        np.save('trans_cruve.npy', table)



def trainer(d, epsilon, epoch, n_asymptotic, idx):
    """Run one of the numbered examples (``idx`` in 1..4) over a range of ``n``.

    For every ``n`` in ``n_asymptotic`` the parameters ``mu``, ``p`` and ``q``
    are derived from ``n`` according to the selected example, and the
    simulation is repeated ``epoch`` times.  Each repetition draws a ``CSBM``
    graph and ``high_dim_Gaussian`` features and scores three predictions
    against the node labels: ``np.sign`` of the raw features, the
    ``message_passing`` labels and the ``linear_convolution`` labels.
    Per-repetition accuracies, their mean and standard deviation, and the
    elapsed time are printed; nothing is saved or returned.

    Args:
        d: Feature dimension; ``mu`` is a length-``d`` constant vector.
        epsilon: Forwarded unchanged to ``CSBM``.
        epoch: Number of repetitions per ``n``.
        n_asymptotic: Iterable of ``n`` values to sweep.
        idx: Example number selecting the (mu, p, q) schedule; 1, 2, 3 or 4.

    Returns:
        None.

    Raises:
        NotImplementedError: If ``idx`` is not one of 1, 2, 3, 4.
    """
    print('*****************************************************')
    print(f'Example {idx}')
    print('*****************************************************')

    def parameter_setup(n, idx):
        # Each branch defines how mu, p and q scale with n for one example.
        if idx == 1:
            mu = (0.05 / math.pow((math.log(n)), 1 / 5)) * np.ones(d)
            p = (math.pow(math.log(n), 2) + math.pow(math.log(n), 7 / 4)) / n
            q = (math.pow(math.log(n), 2) - math.pow(math.log(n), 7 / 4)) / n
        elif idx == 2:
            mu = 0.03 * math.log(n) / math.pow(n, 1 / 4) * np.ones(d)
            p = (2 * math.sqrt(n)) / n
            q = (math.sqrt(n)) / n
        elif idx == 3:
            mu = 0.03 * math.sqrt(math.log(n)) / math.pow(n, 1 / 9) * np.ones(d)
            p = (math.sqrt(n) + math.pow(n, 3 / 8)) / n
            q = (math.sqrt(n) - math.pow(n, 3 / 8)) / n
        elif idx == 4:
            mu = (0.03 / math.pow((math.log(n)), 1 / 4)) * np.ones(d)
            p = (math.pow(math.log(n), 2) + math.pow(math.log(n), 15 / 8)) / n
            q = (math.pow(math.log(n), 2) - math.pow(math.log(n), 15 / 8)) / n
        else:
            raise NotImplementedError
        return mu, p, q

    def forward(n, epsilon, p, q, mu):
        A, Node_labels = CSBM(n, epsilon, p, q)
        X, X_tilde = high_dim_Gaussian(Node_labels, d, mu)
        # linear convolution
        G_linear, Y_linear = linear_convolution(X, A, d)
        # nonlinear convolution
        G_mp, Y_mp = message_passing(X, A, p, q, d)
        # X_mp / X_linear are returned for completeness; the loop below only
        # uses the predicted labels, the raw features and the true labels.
        L = 2 * d * X
        X_mp = L + A * G_mp
        X_linear = L + A * G_linear
        return X_mp, X_linear, G_linear, Y_linear, G_mp, Y_mp, X, Node_labels, A

    for n in n_asymptotic:
        sta = time.time()
        mu, p, q = parameter_setup(n, idx)
        original_record = np.zeros(epoch)
        mp_record = np.zeros(epoch)
        linear_record = np.zeros(epoch)
        print('--------------------------------------------')
        print('--------------------------------------------')
        print('n = ', n)
        print('p = ', p, ', q = ', q)
        print('mu_norm = ', LA.norm(mu))
        for j in range(epoch):
            print('Epoch: ' + str(j))
            X_mp, X_linear, G_linear, Y_linear, G_mp, Y_mp, X, Node_labels, A = forward(n, epsilon, p, q, mu)

            # Calculate the acc for original data
            acc_org = sum(np.sign(X) == Node_labels) / n
            print("Accuracy of original data: ", acc_org)

            # Calculate the acc for non-linear aggregation data
            acc_mp = sum(Y_mp == Node_labels) / n
            print("Accuracy of non-linear aggregation data: ", acc_mp)

            # Calculate the acc for linear aggregation data
            acc_linear = sum(Y_linear == Node_labels) / n
            print("Accuracy of linear aggregation data: ", acc_linear)

            # record acc each time
            original_record[j] = acc_org
            mp_record[j] = acc_mp
            linear_record[j] = acc_linear
            print("------------------------------------------------")
        original_avg = sum(original_record) / epoch
        mp_avg = sum(mp_record) / epoch
        linear_avg = sum(linear_record) / epoch
        print("Average acc of original:", original_avg)
        print("Average acc of mp:", mp_avg)
        print("Average acc of linear:", linear_avg)
        # The spread is taken over the per-repetition accuracies; the std of
        # the scalar mean would always be 0.
        print("STD of original:", np.std(original_record))
        print("STD of mp:", np.std(mp_record))
        print("STD of linear:", np.std(linear_record))
        print("Dt:", time.time() - sta)
